# -*- coding: utf-8 -*-

import json
import os
import re
import shutil
import traceback
from collections import defaultdict
from random import shuffle

import cv2

# Root folder that is scanned (recursively) for input images.
img_dir = "./WIDER_train"

# Collect the full path of every file found anywhere beneath img_dir.
file_list = []
for folder, _subdirs, names in os.walk(img_dir):
    file_list.extend(os.path.join(folder, name) for name in names)

# Number of collected files; used later to size the train/val split.
img_count = len(file_list)

print("Total image number: %d" % img_count)

# Make the output directories for the train/val images and labels.
# os.makedirs is used instead of shelling out to `mkdir` so that this works on
# every platform and raises on failure rather than silently carrying on.
for out_dir in (
    "./human_face_train_images",
    "./human_face_train_labels",
    "./human_face_val_images",
    "./human_face_val_labels",
):
    os.makedirs(out_dir, exist_ok=True)

# Shuffle the collected paths in place so that the train/val split further
# down picks images in random order.
# NOTE(review): no random seed is set in the visible code, so the split
# presumably differs from run to run -- confirm whether that is intended.
shuffle(file_list)

# Read the annotation file, one stripped line per list entry.
with open("./wider_face_train_bbx_gt.txt", "r", encoding="utf-8") as h:
    content = [_.strip() for _ in h]

# Split the annotation lines into one segment per image.
# A segment starts at a line containing "." (the image path) and runs up to
# the next such line. Within a segment: line 0 is the image path, line 1 is
# the number of annotated faces, and the remaining lines are the boxes, whose
# first four fields are "left top width height".
line_index = [i for i, line in enumerate(content) if "." in line]
line_index.append(len(content))  # sentinel closing the final segment

segments = [
    content[start:end] for start, end in zip(line_index, line_index[1:])
]

# Map image file name (without directories) -> list of [left, top, w, h]
# string quadruples.
img_box_dict = defaultdict(list)
for segment in segments:
    if len(segment) < 2:
        # Path line with no count line after it: nothing to record.
        continue
    try:
        declared = int(segment[1])
    except ValueError:
        # Count line is not a number; fall back to every remaining line.
        declared = len(segment) - 2

    image_name = segment[0].split('/')[-1]
    # Honour the declared face count: images with zero faces are followed by
    # an all-zero placeholder row in the annotation file, which must not be
    # turned into a box. Slicing by the count also drops trailing blank lines.
    for box_line in segment[2:2 + declared]:
        fields = box_line.split()[:4]
        if len(fields) == 4:  # ignore blank / truncated rows
            img_box_dict[image_name].append(fields)

# Copy images to the right place and write the converted labels into txt files.
# Fraction of the (shuffled) images that goes to the training split; the rest
# becomes the validation split.
train_part = 0.8


def _yolo_line(box, width, height):
    """Convert one box to a label line with normalised centre/size values.

    Args:
        box: four values (left, top, w, h) in pixels, as strings or ints.
        width: image width in pixels.
        height: image height in pixels.

    Returns:
        The line "0 x_center y_center w h\\n", with every value divided by the
        image width or height respectively. The class id is always 0.
    """
    # A value of 0 is nudged to 0.1 so that no coordinate is written as 0.
    left, top, w, h = [int(v) or 0.1 for v in box]
    x_center = (left + w / 2) / width
    y_center = (top + h / 2) / height
    return "0 %s %s %s %s\n" % (x_center, y_center, w / width, h / height)


def _export_split(split, start, stop):
    """Export file_list[start:stop] as the given split ("train" or "val").

    For every readable image this copies the file into
    ./human_face_<split>_images, appends its new path to
    ./human_face_<split>.txt and writes one label file into
    ./human_face_<split>_labels. Reads the module-level ``file_list`` and
    ``img_box_dict``.

    Args:
        split: split name used to build the output paths.
        start: first index into file_list (inclusive).
        stop: last index into file_list (exclusive).
    """
    images_dir = "./human_face_%s_images" % split
    labels_dir = "./human_face_%s_labels" % split

    # The list file is opened once, in append mode, for the whole split.
    with open("./human_face_%s.txt" % split, "a", encoding="utf-8") as listing:
        for i in range(start, stop):
            src = file_list[i]
            print(i, src)

            # Read first (as grayscale): cv2.imread returns None for files it
            # cannot decode, and such files must not end up in the data set.
            img = cv2.imread(src, 0)
            if img is None:
                print("Skipping unreadable image: %s" % src)
                continue
            height, width = img.shape

            file = os.path.basename(src)
            # shutil.copy avoids the shell, so paths containing spaces or
            # shell metacharacters are copied correctly.
            shutil.copy(src, "%s/%s" % (images_dir, file))
            listing.write("%s/%s\n" % (images_dir, file))

            # Swap only the real extension for ".txt".
            label_name = os.path.splitext(file)[0] + ".txt"
            with open("%s/%s" % (labels_dir, label_name), "w", encoding="utf-8") as f:
                f.writelines(
                    _yolo_line(label, width, height)
                    for label in img_box_dict[file]
                )


# The validation split starts exactly where the training split ends, so that
# no image is left out between the two.
split_index = int(train_part * img_count)

# train data
_export_split("train", 0, split_index)

# val data
_export_split("val", split_index, img_count)